from __future__ import division, print_function, absolute_import

import numpy as np
from CDTools.tools.plotting import *
from CDTools.tools.analysis import *
from CDTools.tools.image_processing import *
from CDTools.tools.cmath import *
import CDTools
from matplotlib import pyplot as plt
import matplotlib as mpl
import pickle
from scipy import fftpack as ffts
import torch as t
from RPI_tools import *


#
# This file runs all the RPI analysis from the optical experiments.
#

#
# First, we load the three datasets we need
#
ptycho_filename = 'data/Optical_Data_ptycho.cxi'
ss_filename = 'data/Optical_Data_ss.cxi'
ss_defocused_filename = 'data/Optical_Data_ss_defocused.cxi'

# All three .cxi files are read with the same loader, in the order listed
ptycho_dataset, ss_dataset, ss_defocused_dataset = (
    CDTools.datasets.Ptycho2DDataset.from_cxi(cxi_filename)
    for cxi_filename in (ptycho_filename, ss_filename, ss_defocused_filename)
)


# Then, we load the results of the calibration ptychography scan
with open('data/Optical_ptycho.pickle', 'rb') as calibration_file:
    # NOTE: unpickling can execute arbitrary code; only load trusted files
    ptycho_results = pickle.load(calibration_file)


# From the calibration results we take the first entry of the probe, the
# background, and the basis. The wavelength is read from the defocused dataset.
nearfield = ptycho_results['probe'][0]
background = ptycho_results['background']
basis = ptycho_results['basis']
wavelength = ss_defocused_dataset.wavelength

# The pattern is the second element of the first item in each dataset
ss_intensities = ss_dataset[0][1]
ss_defocused_intensities = ss_defocused_dataset[0][1]

# Everything passed to the reconstructions is converted to float32 tensors
t_nearfield = complex_to_torch(nearfield).to(dtype=t.float32)
t_background, t_intensities, t_defocused_intensities = (
    t.Tensor(array).to(dtype=t.float32)
    for array in (background, ss_intensities, ss_defocused_intensities)
)


#
# Now, we perform the RPI phase retrievals
#

# First, from the diffraction pattern extracted from the calibration scan
resolution = 400
# Round the requested resolution down to an even number
even_resolution = 2 * (resolution // 2)

# Keyword arguments shared by both reconstructions
shared_options = dict(background=t_background, optimizer='Adam', GPU=True,
                      schedule=True, numtries=20)

retrieved, losses = reconstruct(
    t_intensities, t_nearfield, even_resolution, 0.4, 1000,
    **shared_options)

retrieval = torch_to_complex(retrieved)

# Then, from the second pattern taken several hours later and defocused.
# Eleven evenly spaced candidate defocus values are passed to the retrieval
# (presumably in meters -- TODO confirm against reconstruct_defocus).
defocuses = np.linspace(-3.6, -3.2, 11)*1e-3
defocus_retrieved, defocus_losses = reconstruct_defocus(
    t_defocused_intensities, t_nearfield, basis, wavelength,
    even_resolution, 0.4, 1000, defocuses, verbose=True,
    **shared_options)

prop_retrieval = torch_to_complex(defocus_retrieved)

# Inspect the retrieved probe defocus
# plt.figure()
# plt.plot(defocuses, defocus_losses)
# plt.show()
# exit()

# Now we load some information to help with analysis
# Calibration quantities used for the comparison and the plots below
probe = ptycho_results['probe'][0]
ptycho_obj = ptycho_results['obj']
basis = ptycho_results['basis']

# Ratio of the measured pattern's size to the retrieved image's size along
# the first axis. This is a true (floating point) division.
n_pattern_pixels = ss_intensities.shape[0]
n_retrieval_pixels = retrieval.shape[0]
pixel_ratio = n_pattern_pixels / n_retrieval_pixels


# This defines the regions of interest in the ptycho and single shot images
#fov = 200
#offset1 = [110,115]
#offset2 = [915,795]
#offset3 = [95,95]
fov = 240
offset1 = [90,95]
offset2 = [880,785]
offset3 = [75,75]

# The ptychography region is larger by the pixel ratio, truncated to an int
big_fov = int(pixel_ratio*fov)

# Margin on each side of the big region's Fourier transform that is discarded
# when it is cropped down to fov x fov
pad = (big_fov - fov)//2


def _sin_squared_window(size):
    """Return a (size, size) float32 separable sin^2 (Hann) window.

    The 1D profile is sin(pi * n / (size - 1))**2 for n = 0 .. size - 1, so
    it is zero at both ends; the 2D window is the outer product of that
    profile with itself.
    """
    profile = np.sin(np.pi*np.arange(size)/(size-1))**2
    return (profile[:, None] * profile[None, :]).astype(np.float32)


big_window = _sin_squared_window(big_fov)
lil_window = _sin_squared_window(fov)

#plt.close('all')
#plt.imshow(big_window)
#plt.figure()
#plt.imshow(lil_window)
#plt.show()

# Now we extract the images and downsample the ptychography
# Index expressions for the three regions of interest
roi1 = np.s_[offset1[0]:offset1[0]+fov, offset1[1]:offset1[1]+fov]
roi2 = np.s_[offset2[0]:offset2[0]+big_fov, offset2[1]:offset2[1]+big_fov]
roi3 = np.s_[offset3[0]:offset3[0]+fov, offset3[1]:offset3[1]+fov]

# Each region is cut out and multiplied by the window of matching size
im1 = lil_window * retrieval[roi1].copy()
im2 = big_window * ptycho_obj[roi2].copy()
im3 = lil_window * prop_retrieval[roi3].copy()

# The following lines correct for the phase ramp in the defocused
# reconstruction by hand. This comes about because the detector was moved
# slightly between reconstructions, changing the location of the
# zero-frequency pixel. The ramp has a period of 75 pixels along the second
# axis and is applied in place.
im3_y, im3_x = np.mgrid[:im3.shape[0],:im3.shape[1]]
hand_tuned_ramp = np.exp(2j*np.pi*im3_x / 75)
im3 *= hand_tuned_ramp

# Downsample the ptychography image: keep only the central fov x fov block
# of its centered Fourier transform and transform back
im2_fft = ffts.fftshift(np.fft.fft2(ffts.ifftshift(im2)))
im2_fft_center = im2_fft[pad:pad+fov, pad:pad+fov]
im2 = ffts.fftshift(np.fft.ifft2(ffts.ifftshift(im2_fft_center)))


# We find the correct overlaps between the images
# The shifts are measured on the full windowed images, before any trimming
shift = find_shift(complex_to_torch(im1),complex_to_torch(im2))
shift3 = find_shift(complex_to_torch(im3),complex_to_torch(im2))
#im1 = np.roll(im1, -shift, axis=(0,1))


# And shift them to overlap correctly: the downsampled ptychography image is
# shifted once for each single-shot retrieval it is compared against
im21 = torch_to_complex(sinc_subpixel_shift(complex_to_torch(im2), shift))
im23 = torch_to_complex(sinc_subpixel_shift(complex_to_torch(im2), shift3))

# Drop a 10 pixel border from every image so each pair keeps matching shapes
# (presumably to discard edge effects from the shift -- TODO confirm)
border_trim = np.s_[10:-10, 10:-10]
im1, im21 = im1[border_trim], im21[border_trim]
im3, im23 = im3[border_trim], im23[border_trim]


# Now we calculate the FRCs
# The basis is scaled by the same pixel ratio used for the downsampling
big_basis = pixel_ratio * basis

# Both comparisons use identical FRC settings
frc_options = dict(im_slice=np.s_[:,:], snr=1, nbins=25, window='None')
freqs, frc, threshold = calc_frc(im1, im21, big_basis, **frc_options)
freqs3, frc3, threshold3 = calc_frc(im3, im23, big_basis, **frc_options)


# And plot the relevant results
# FRC comparison figure. Frequencies are divided by 1e3 to match the
# "cycles per mm" axis label.
plt.figure(figsize=(3.5,3))
plt.plot(freqs/1e3, frc, label='Ideal')
# Plot the defocused curve against the frequency axis returned with it
# (freqs3), rather than the axis from the other FRC calculation
plt.plot(freqs3/1e3, frc3, label='Defocused')
plt.plot(freqs/1e3, threshold, label='Threshold')
plt.xlabel('Spatial Frequency (cycles per mm)')
plt.ylabel('FRC')
plt.legend()
plt.tight_layout()

# Calibration probe: amplitude, colorized view, and the amplitude of its
# centered Fourier transform
plot_amplitude(probe, basis=basis)
plot_colorized(probe, basis=basis)
plot_amplitude(ffts.fftshift(np.fft.fft2(ffts.ifftshift(probe))))
# Ptychography object followed by the two single-shot retrievals
plot_amplitude(ptycho_obj, basis=basis)
plot_amplitude(retrieval)
plot_amplitude(prop_retrieval)
# Square root of the measured single-shot intensities
plot_amplitude(np.sqrt(ss_intensities.numpy()))
plt.show()
